# -*- coding:utf-8 -*-

from utils.log_utils import ClampLog
from model.base.model import Model
from model.predict.load_predict import set_predict_in_model
from model.net.load_net import set_net_in_model

"""
动态加载模型
"""


def model_train():
    """
    Train a freshly created model.

    A new ``Model`` is created, its network is attached via
    ``set_net_in_model`` (inside a ClampLog context labelled 'building net'),
    and then ``train()`` is called on the result (inside a ClampLog context
    labelled 'training').
    :return: None
    """
    base_model = Model()
    with ClampLog('building net'):
        # set_net_in_model returns the model to use from here on; keep its result.
        trainable = set_net_in_model(base_model)
    with ClampLog('training'):
        trainable.train()


def validate_x(x, expected_shape=None):
    """
    Check whether the input x given at prediction time matches the expected shape.

    With the default ``expected_shape=None`` no check is performed and True
    is always returned, which is the historical placeholder behavior. When
    ``expected_shape`` is given, x must expose a non-empty ``shape`` whose
    trailing dimensions equal ``expected_shape`` element-wise. Any leading
    dimensions beyond those are ignored.

    :param x: input to validate; only its ``shape`` attribute is inspected
    :param expected_shape: optional sequence of trailing dimension sizes to require
    :return: True if x is accepted, otherwise False
    """
    if expected_shape is None:
        return True
    shape = getattr(x, 'shape', None)
    # Missing shape, None, or an empty (scalar) shape is rejected.
    if not shape:
        return False
    expected = tuple(expected_shape)
    # Fewer dimensions than required can never match; also avoids an IndexError.
    if len(shape) < len(expected):
        return False
    return tuple(shape[len(shape) - len(expected):]) == expected


def model_predict(x):
    """
    Run prediction on x with a freshly built model.

    x is first checked with ``validate_x``. A new ``Model`` is then created,
    its network is attached via ``set_net_in_model`` ('building net' ClampLog
    context), and its predict function is attached via ``set_predict_in_model``
    and invoked ('predicting' ClampLog context).
    :param x: input data to predict on
    :return: whatever the attached ``predict`` returns
    :raises ValueError: if ``validate_x`` rejects x
    """
    if not validate_x(x):
        raise ValueError('传入的数据不符合网络定义的shape')

    predictor = Model()
    with ClampLog('building net'):
        predictor = set_net_in_model(predictor)
    with ClampLog('predicting'):
        predictor = set_predict_in_model(predictor)
        # The model is passed explicitly as the first argument. Presumably
        # set_predict_in_model attaches an unbound function; confirm before
        # changing this call.
        y_pred = predictor.predict(predictor, x)
    return y_pred

